#! /usr/bin/env python3
"""
@author  : MG
@Time    : 2021/10/11 8:53
@File    : to_vnpy.py
@contact : mmmaaaggg@163.com
@desc    : 用于将 wind 期货数据（复权因子、日线、1分钟线）同步到 vnpy 数据库
"""
import logging

import pandas as pd
from ibats_utils.db import with_db_session
from ibats_utils.mess import datetime_2_str
from sqlalchemy.dialects.mysql import DOUBLE
from sqlalchemy.exc import OperationalError
from sqlalchemy.types import String, Date

from tasks import app
from tasks import config
from tasks.backend import engine_md, engine_dic

# Module-wide logger shared by every sync function below.
# NOTE(review): this is the root logger (no name passed); `logging.getLogger(__name__)` is the
# conventional choice, but it is left as-is in case logging configuration elsewhere relies on it.
logger = logging.getLogger()


def _reversion_rights_factors_2_vnpy(table_name='wind_future_adj_factor', db_schema_name=config.DB_SCHEMA_VNPY):
    """
    Copy the futures adjustment-factor table into the database selected by ``db_schema_name``.

    The actual copy is delegated to
    ``tasks.wind.future_reorg.reversion_rights_factor.backup_to_db``; this function only chooses
    the target engine and the SQL column types used for the target table.

    :param table_name: name of the table to copy (default ``'wind_future_adj_factor'``)
    :param db_schema_name: key into ``engine_dic`` selecting the target database
    :return: None. An ``OperationalError`` raised during the copy is logged and swallowed;
        a missing ``db_schema_name`` key in ``engine_dic`` is not caught.
    """
    from tasks.wind.future_reorg.reversion_rights_factor import backup_to_db

    logger.info(f"开始向 {db_schema_name} 同步 {table_name} 数据")
    # SQL column types for the target table, keyed by column name.
    column_types = dict(
        trade_date=Date,
        instrument_id_main=String(20),
        adj_factor_main=DOUBLE,
        instrument_id_secondary=String(20),
        adj_factor_secondary=DOUBLE,
        instrument_type=String(20),
        method=String(20),
    )
    target_engine = engine_dic[db_schema_name]
    try:
        backup_to_db(table_name=table_name, new_engine=target_engine, dtype=column_types)
    except OperationalError:
        # Best effort: a connection failure is logged (with traceback) but does not propagate.
        logger.exception("链接数据库失败")


def reversion_rights_factors_2_vnpy(table_name='wind_future_adj_factor'):
    """
    Sync the futures adjustment-factor table into the ``config.DB_SCHEMA_VNPY`` database.

    :param table_name: table to copy, default ``'wind_future_adj_factor'``
    """
    target_schema = config.DB_SCHEMA_VNPY
    _reversion_rights_factors_2_vnpy(table_name=table_name, db_schema_name=target_schema)
    # Syncing the same table into config.DB_SCHEMA_VNPY_PA is currently disabled:
    # _reversion_rights_factors_2_vnpy(table_name=table_name, db_schema_name=config.DB_SCHEMA_VNPY_PA)


def _wind_future_daily_2_vnpy(
        db_schema_name=config.DB_SCHEMA_VNPY, instrument_types=None):
    """
    Copy Wind futures daily bars into the vnpy ``dbbardata`` table of ``db_schema_name``.

    For every wind code returned by ``get_wind_code_list_by_types(instrument_types)``, the daily
    rows of ``wind_future_daily`` are read from ``engine_md``. Rows with ``close`` = 0 are excluded
    and rows with any NULL column are dropped.

    If the target already holds exactly as many rows for that symbol with interval ``'1d'``, the
    symbol is skipped. Otherwise the existing ``'1d'`` rows of the symbol are deleted and the full
    daily history is appended again.

    :param db_schema_name: key into ``engine_dic`` selecting the target vnpy database
    :param instrument_types: passed through to ``get_wind_code_list_by_types``
    :return: None. Returns early (after logging) if the target database cannot be reached or has
        no ``dbbardata`` table.
    """
    from tasks.wind import WIND_VNPY_EXCHANGE_DIC, get_wind_code_list_by_types
    table_name = 'dbbardata'
    interval = '1d'
    logger.info(f"开始向 {db_schema_name} 同步 {table_name} {interval} 数据")
    engine_vnpy = engine_dic[db_schema_name]
    try:
        has_table = engine_vnpy.has_table(table_name)
    except OperationalError:
        logger.exception("链接数据库失败")
        return
    if not has_table:
        logger.error('当前数据库 %s 没有 %s 表，建议使用 vnpy先建立相应的数据库表后再进行导入操作', engine_vnpy, table_name)
        return

    # The SQL below does not change between iterations, so it is built once.
    # ``interval`` is bound as a parameter (the same way _min_to_vnpy_increment does it) rather
    # than hard-coded, so the count/delete filter cannot drift from the value written into the
    # ``interval`` column further down.
    sql_daily_str = "select trade_date `datetime`, `open` open_price, high high_price, " \
                    "`low` low_price, `close` close_price, volume, position as open_interest " \
                    "from wind_future_daily where wind_code = %s and `close` <> 0"
    sql_count_str = f"select count(1) from {table_name} where symbol=:symbol and `interval`=:interval"
    sql_delete_str = f"delete from {table_name} where symbol=:symbol and `interval`=:interval"

    wind_code_list = get_wind_code_list_by_types(instrument_types)
    wind_code_count = len(wind_code_list)
    for n, wind_code in enumerate(wind_code_list, start=1):
        symbol, exchange = wind_code.split('.')
        if exchange in WIND_VNPY_EXCHANGE_DIC:
            exchange_vnpy = WIND_VNPY_EXCHANGE_DIC[exchange]
        else:
            logger.warning('exchange: %s 在交易所列表中不存在', exchange)
            # Fall back to the raw Wind exchange suffix.
            exchange_vnpy = exchange

        # Read the full daily history of this contract; rows with any NULL column are dropped.
        df = pd.read_sql(sql_daily_str, engine_md, params=[wind_code]).dropna()
        df_len = df.shape[0]
        if df_len == 0:
            continue

        df['symbol'] = symbol
        df['exchange'] = exchange_vnpy
        df['interval'] = interval

        bind_params = {'symbol': symbol, 'interval': interval}
        with with_db_session(engine_vnpy) as session:
            existed_count = session.scalar(sql_count_str, params=bind_params)
            # Same number of rows already stored -> treat this symbol as up to date.
            if existed_count == df_len:
                continue
            # Otherwise replace: drop the stale daily bars of this symbol before appending.
            if existed_count > 0:
                session.execute(sql_delete_str, params=bind_params)
                session.commit()

        # NOTE(review): the delete above is committed before this append runs on a separate
        # connection; if the append fails, the symbol is left without (complete) daily bars until
        # a later run detects the count mismatch and re-inserts them.
        df.to_sql(table_name, engine_vnpy, if_exists='append', index=False)
        logger.info("%d/%d) %s %d data -> %s interval %s",
                    n, wind_code_count, symbol, df_len, table_name, interval)


@app.task
def wind_future_daily_2_vnpy(
        chain_param=None, instrument_types=None, start_vnp=True):
    """
    Celery task: sync Wind futures daily bars into the ``config.DB_SCHEMA_VNPY`` database.

    :param chain_param: not used by the body; presumably accepted so the task can sit in a
        Celery chain — TODO confirm
    :param instrument_types: forwarded to ``_wind_future_daily_2_vnpy``
    :param start_vnp: currently unused (the VPN restart and the DB_SCHEMA_VNPY_PA sync are disabled)
    """
    sync_kwargs = dict(db_schema_name=config.DB_SCHEMA_VNPY, instrument_types=instrument_types)
    _wind_future_daily_2_vnpy(**sync_kwargs)

    # Disabled: restart the VPN, then repeat the sync against config.DB_SCHEMA_VNPY_PA.
    # if start_vnp:
    #     from tasks.utils.vpn import restart_vpn
    #     restart_vpn()
    #
    # _wind_future_daily_2_vnpy(db_schema_name=config.DB_SCHEMA_VNPY_PA, instrument_types=instrument_types)


def _min_to_vnpy_increment(db_schema_name=config.DB_SCHEMA_VNPY, instrument_types=None):
    """
    Incrementally copy Wind futures 1-minute bars into the vnpy ``dbbardata`` table of ``db_schema_name``.

    For each wind code, the latest ``datetime`` already stored in the target for that symbol with
    interval ``'1m'`` serves as a watermark. Only ``wind_future_min`` rows strictly after it are
    read; all rows are read when the target holds none. Rows whose ``close`` is NULL or 0, and rows
    with any other NULL column (``dropna``), are not copied.

    NOTE(review): a row dropped by ``dropna`` that is older than the newest row written in the same
    run will not be picked up later, because the watermark moves past it.

    :param db_schema_name: key into ``engine_dic`` selecting the target vnpy database
    :param instrument_types: passed through to ``get_wind_code_list_by_types``
    :return: None. Returns early (after logging) if the target database cannot be reached or has
        no ``dbbardata`` table.
    """
    # ``engine_dic`` comes from the module-level ``tasks.backend`` import; only the wind helpers
    # are imported lazily here.
    from tasks.wind import get_wind_code_list_by_types, WIND_VNPY_EXCHANGE_DIC
    table_name = 'dbbardata'
    interval = '1m'
    logger.info(f"开始向 {db_schema_name} 同步 {table_name} {interval} 数据")
    engine_vnpy = engine_dic[db_schema_name]
    try:
        has_table = engine_vnpy.has_table(table_name)
    except OperationalError:
        logger.exception("链接数据库失败")
        return
    if not has_table:
        logger.error('当前数据库 %s 没有 %s 表，建议使用 vnpy先建立相应的数据库表后再进行导入操作', engine_vnpy, table_name)
        return

    # Source queries on wind_future_min. Both share the select list and the close filter; the
    # incremental variant additionally keeps only rows newer than the watermark.
    sql_select_str = "select trade_datetime `datetime`, `open` open_price, high high_price, " \
                     "`low` low_price, `close` close_price, volume, position as open_interest " \
                     "from wind_future_min where wind_code = %s and "
    sql_close_filter_str = "`close` is not null and `close` <> 0"
    sql_increment_str = sql_select_str + "trade_datetime > %s and " + sql_close_filter_str
    sql_whole_str = sql_select_str + sql_close_filter_str
    # Watermark query: latest bar already stored in the target for a given symbol/interval.
    sql_max_datetime_str = f"select max(`datetime`) from {table_name} where symbol=:symbol and `interval`=:interval"

    wind_code_list = get_wind_code_list_by_types(instrument_types)
    wind_code_count = len(wind_code_list)
    for n, wind_code in enumerate(wind_code_list, start=1):
        symbol, exchange = wind_code.split('.')
        if exchange in WIND_VNPY_EXCHANGE_DIC:
            exchange_vnpy = WIND_VNPY_EXCHANGE_DIC[exchange]
        else:
            logger.warning('%s exchange: %s 在交易所列表中不存在', wind_code, exchange)
            # Fall back to the raw Wind exchange suffix.
            exchange_vnpy = exchange
        with with_db_session(engine_vnpy) as session:
            datetime_exist = session.scalar(
                sql_max_datetime_str, params={'symbol': symbol, 'interval': interval})

        # Read the 1-minute bars: only those after the watermark if the target already has data.
        if datetime_exist is not None:
            df = pd.read_sql(sql_increment_str, engine_md, params=[wind_code, datetime_exist]).dropna()
        else:
            df = pd.read_sql(sql_whole_str, engine_md, params=[wind_code]).dropna()

        df_len = df.shape[0]
        if df_len == 0:
            continue

        df['symbol'] = symbol
        df['exchange'] = exchange_vnpy
        df['interval'] = interval
        datetime_latest = df['datetime'].max().to_pydatetime()
        df.to_sql(table_name, engine_vnpy, if_exists='append', index=False)
        logger.info("%d/%d) %s (%s ~ %s] %d data -> %s interval %s",
                    n, wind_code_count, symbol,
                    datetime_2_str(datetime_exist), datetime_2_str(datetime_latest),
                    df_len, table_name, interval)


@app.task
def min_to_vnpy_increment(
        chain_param=None, instrument_types=None, start_vnp=False):
    """
    Celery task: incrementally sync Wind futures 1-minute bars into ``config.DB_SCHEMA_VNPY``.

    :param chain_param: not used by the body; presumably accepted so the task can sit in a
        Celery chain — TODO confirm
    :param instrument_types: forwarded to ``_min_to_vnpy_increment``
    :param start_vnp: currently unused (the VPN restart and the DB_SCHEMA_VNPY_PA sync are disabled)
    """
    target_schema = config.DB_SCHEMA_VNPY
    _min_to_vnpy_increment(instrument_types=instrument_types, db_schema_name=target_schema)
    # Disabled: restart the VPN, then repeat the sync against config.DB_SCHEMA_VNPY_PA.
    # if start_vnp:
    #     from tasks.utils.vpn import restart_vpn
    #     restart_vpn()
    #
    # _min_to_vnpy_increment(instrument_types=instrument_types, db_schema_name=config.DB_SCHEMA_VNPY_PA)


def to_vnpy_pa(start_vnp=True, db_schema_name=config.DB_SCHEMA_VNPY_PA, instrument_types=None):
    """
    Run the full sync (adjustment factors, daily bars, 1-minute bars) into ``db_schema_name``.

    :param start_vnp: when true, call ``tasks.utils.vpn.restart_vpn()`` before syncing
    :param db_schema_name: key into ``engine_dic`` selecting the target database
        (default ``config.DB_SCHEMA_VNPY_PA``)
    :param instrument_types: forwarded to both the daily-bar and the 1-minute-bar sync so they
        cover the same instruments; ``None`` (the default) is passed through unchanged
    """
    if start_vnp:
        from tasks.utils.vpn import restart_vpn
        restart_vpn()

    _reversion_rights_factors_2_vnpy(db_schema_name=db_schema_name)
    _wind_future_daily_2_vnpy(db_schema_name=db_schema_name, instrument_types=instrument_types)
    _min_to_vnpy_increment(db_schema_name=db_schema_name, instrument_types=instrument_types)


if __name__ == "__main__":
    # Manual entry point: run the full sync into config.DB_SCHEMA_VNPY_PA, calling restart_vpn() first.
    # Alternative single-step runs, kept for reference. reversion_rights_factors_2_vnpy does not
    # accept a ``start_vnp`` argument, so it must be called without one.
    # reversion_rights_factors_2_vnpy()
    # wind_future_daily_2_vnpy(start_vnp=True)
    # min_to_vnpy_increment(start_vnp=True)
    to_vnpy_pa(start_vnp=True)
